import torch.nn as nn
import numpy as np
import torch
import copy
#from base_func import *
import torch.nn.functional as F
from torchvision.models import resnet18, resnet50

from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier

def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``.

    Delegates to ``torch.sigmoid``, which is numerically stable: the naive formula
    overflows ``exp(-x)`` for large negative ``x`` and then back-propagates NaN.
    """
    return torch.sigmoid(x)

def softplus(x):
    """Softplus ``log(1 + exp(x))`` computed stably.

    Uses ``F.softplus`` rather than the naive formula. The naive form overflows to
    ``inf`` for large ``x``, whereas ``F.softplus`` returns ``x`` above its threshold.
    """
    return F.softplus(x)

class encode_mean_std_pair(nn.Module):
    """Three parallel MLP heads applied to the same input.

    Every head has the layout Linear(input_dim, hidden_dim) -> ReLU -> Dropout ->
    Linear(hidden_dim, out). ``encode_mean`` has a width-1 output, ``encode_std``
    has a width-1 output passed through Softplus (so it is non-negative), and
    ``encode_b`` has an output of width ``input_dim``.
    """

    def __init__(self, input_dim, hidden_dim, dropout_rate):
        super().__init__()

        def make_head(out_width, *extra_layers):
            # Same layer order for every head, so state_dict indices stay 0/3 (and 4).
            return nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_dim, out_width),
                *extra_layers,
            )

        self.encode_mean = make_head(1)
        self.encode_std = make_head(1, nn.Softplus())
        self.encode_b = make_head(input_dim)

    def forward(self, x):
        """Return the tuple ``(mean, std, b)`` produced by the three heads on ``x``."""
        return self.encode_mean(x), self.encode_std(x), self.encode_b(x)


class ExpertAttention(nn.Module):
    """Attention-based gate that produces softmax weights over the experts of a mixture.

    A single query vector attends over the input tokens with multi-head attention.
    By default the query is the mean of the tokens. The attended vector is mapped to
    one score per expert, and the scores are normalised with a softmax.
    """

    def __init__(self, feature_dim, n_way, experts, d_model = 256, n_heads = 4):
        """
        Args:
            feature_dim: size of each input token's feature vector.
            n_way: number of classes per task (stored; not used in the computation).
            experts: number of experts to produce weights for.
            d_model: attention embedding size (nn.MultiheadAttention requires it to
                be divisible by ``n_heads``).
            n_heads: number of attention heads.
        """
        super(ExpertAttention, self).__init__()
        self.feature_dim = feature_dim
        self.n_way = n_way
        self.experts = experts
        self.d_model = d_model
        self.n_heads = n_heads

        # query projection
        self.query_proj = nn.Linear(feature_dim, d_model)

        # key/value projections
        self.key_proj = nn.Linear(feature_dim, d_model)
        self.value_proj = nn.Linear(feature_dim, d_model)

        # multi-head attention over the input tokens
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            batch_first=True
        )

        # Output head: one score per expert. A single output would make the softmax
        # over dim 1 identically 1.0, so the gate could not weight experts or learn.
        self.output_layer = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, experts)
        )

        self.softmax = nn.Softmax(dim=1)

    def forward(self, support_features, query_features=None):
        """Return gate weights over the experts.

        Args:
            support_features: [batch_size, n_tokens, feature_dim] tokens to attend
                over. MoE_RIM passes per-class mean support features, so n_tokens == n_way.
            query_features: optional [batch_size, feature_dim] query. When omitted,
                the mean of ``support_features`` over the token axis is used.

        Returns:
            [batch_size, experts] tensor; each row sums to 1.
        """
        batch_size = support_features.size(0)

        # default query: mean over the input tokens
        if query_features is None:
            query_features = support_features.mean(dim=1).view(batch_size, -1)

        query = self.query_proj(query_features).unsqueeze(1)  # [batch_size, 1, d_model]

        keys = self.key_proj(support_features)  # [batch_size, n_tokens, d_model]
        values = self.value_proj(support_features)  # [batch_size, n_tokens, d_model]

        attn_output, _ = self.multihead_attn(query, keys, values)  # [batch_size, 1, d_model]

        # one importance score per expert
        importance_scores = self.output_layer(attn_output.squeeze(1))  # [batch_size, experts]
        importance_scores = importance_scores.view(batch_size, -1)

        # normalise across experts
        return self.softmax(importance_scores)  # [batch_size, experts]



class MoE_RIM(nn.Module):
    """Mixture of RIM experts for few-shot classification.

    A ResNet-18 backbone embeds the images. It is run under ``torch.no_grad``, so it
    receives no gradients. Each RIM expert owns ``n_base`` base classes as its base
    nodes. For every task:
    - each expert fuses the support and query features;
    - a per-task classifier, chosen by ``args['classifier']``, is fitted on the
      fused support features;
    - the per-expert query scores are mixed with weights from an ``ExpertAttention``
      gate.
    """

    def __init__(self, args, base_images, base_labels):
        """
        Args:
            args: dict of hyper-parameters. Keys read: 'experts', 'n_way', 'n_shot',
                'n_query', 'n_gaus', 'n_base', 'edge_dim', 'lambda1', 'lambda2',
                'device', 'classifier', 'ImageSize', 'overlap', 'omega', 'g_dim'.
            base_images: base-class images indexed by class along dim 0, shaped
                [n_classes, n_images, 3, ImageSize, ImageSize].
            base_labels: not used here; kept for interface compatibility.
        """
        super(MoE_RIM, self).__init__()
        self.experts = args['experts']
        self.n_way = args['n_way']
        self.n_shot = args['n_shot']
        self.n_query = args['n_query']
        self.n_gaus = args['n_gaus']
        self.n_base = args['n_base']
        self.edge_dim = args['edge_dim']
        self.lambda1 = args['lambda1']
        self.lambda2 = args['lambda2']
        self.device = args['device']
        self.classifier = args['classifier']
        self.ImageSize = args['ImageSize']
        self.overlap = args['overlap']

        # Feature extractor: pretrained ResNet18 with the final FC layer removed.
        # Another backbone (e.g. resnet50(pretrained=True)) can be used as long as
        # self.feature_dim is updated to that backbone's output width.
        backbone = resnet18(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
        self.feature_dim = 512  # Output dimension of ResNet18 before the final FC

        # Expert i takes n_base consecutive base classes starting at i * stride, so
        # neighbouring experts share int(n_base * overlap) classes.
        n_shared = int(self.n_base * self.overlap)
        stride = self.n_base - n_shared
        self.rim_experts = nn.ModuleList()
        for i in range(self.experts):
            start = i * stride
            expert_base_images = base_images[start:start + self.n_base]
            expert_base_features = self._extract_features(expert_base_images)

            rim = RIM(
                feature_dim=self.feature_dim,
                base_num=self.n_base,
                edge_dim=self.edge_dim,
                base_nodes=expert_base_features,
                n_gaus=self.n_gaus,
                omega=args['omega'],
                last_dense=128,
                out_dim=64,
                g_dim=args['g_dim'],
                dropout=0.1
            )
            self.rim_experts.append(rim)

        # Gating network: attention over the per-class mean support features.
        self.gate = ExpertAttention(self.feature_dim, self.n_way, self.experts, d_model=256, n_heads=4)

    def _extract_features(self, images):
        """Return per-class mean backbone features, shape [n_classes, feature_dim].

        images: [n_classes, n_images, 3, ImageSize, ImageSize].
        """
        n_classes, n_images = images.shape[:2]
        # Run on whatever device the backbone currently lives on.
        backbone_device = next(self.feature_extractor.parameters()).device
        images = images.to(backbone_device).reshape(-1, 3, self.ImageSize, self.ImageSize)

        with torch.no_grad():
            features = self.feature_extractor(images)

        features = features.squeeze(-1).squeeze(-1)  # Remove spatial dimensions
        return features.view(n_classes, n_images, -1).mean(dim=1)  # Average over images

    def NN_classify(self, support_output, labels, query_output):
        """Train a fresh SimpleNN on the support features and return query logits.

        Args:
            support_output: fused support features, reshaped to [-1, feature_dim].
            labels: class-index tensor on ``self.device``, aligned with the
                reshaped support rows.
            query_output: fused query features, reshaped to [-1, feature_dim].

        Returns:
            [n_query_samples, n_way] logits, computed without gradient tracking.
        """
        # enable_grad lets the inner training build a graph even if the caller
        # runs forward under torch.no_grad().
        with torch.enable_grad():
            X_train = support_output.reshape(-1, self.feature_dim)
            y_train = labels

            classifier = SimpleNN(self.feature_dim, self.n_way).to(self.device)
            classifier.train()
            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.Adam(classifier.parameters(), lr=0.0005)

            # NOTE(review): X_train is not detached, and backward uses retain_graph=True.
            # Each of the 50 steps therefore also accumulates gradients into the RIM
            # expert parameters that produced support_output. Confirm this is intended;
            # if not, detach X_train and drop retain_graph.
            for _ in range(50):
                optimizer.zero_grad()
                outputs = classifier(X_train)
                loss = criterion(outputs, y_train)
                loss.backward(retain_graph=True)
                optimizer.step()

            with torch.no_grad():
                query_logits = classifier(query_output.reshape(-1, self.feature_dim))
            return query_logits

    def _classify(self, support_output, labels, query_output):
        """Fit the configured per-task classifier and score the queries.

        Returns a float tensor on ``self.device`` with one row per query sample and
        one column per class. For 'NN' the values are logits; otherwise they are
        ``predict_proba`` probabilities.

        Raises:
            ValueError: if ``self.classifier`` is not a supported name.
        """
        if self.classifier == 'NN':
            return self.NN_classify(support_output, labels, query_output)

        if self.classifier == 'LogisticRegression':
            classifier = LogisticRegression(max_iter=2000)
        elif self.classifier == 'XGBoost':
            classifier = XGBClassifier(eval_metric='mlogloss')
        elif self.classifier == 'SVM':
            classifier = SVC(probability=True, kernel='linear', random_state=42)
        elif self.classifier == 'RF':
            classifier = RandomForestClassifier(n_estimators=100, random_state=42)
        elif self.classifier == 'KNN':
            classifier = KNeighborsClassifier(n_neighbors=3)
        else:
            raise ValueError(f"Unknown classifier: {self.classifier!r}")

        # sklearn/XGBoost take CPU numpy arrays; all gaussian samples are used for training.
        X_train = support_output.reshape(-1, self.feature_dim).detach().cpu().numpy()
        y_train = labels.detach().cpu().numpy()
        classifier.fit(X_train, y_train)

        query_probs = classifier.predict_proba(
            query_output.reshape(-1, self.feature_dim).detach().cpu().numpy())
        return torch.from_numpy(query_probs).float().to(self.device)

    def forward(self, support_set, query_set, support_labels = None, mode="Train"):
        """Classify the query set of every task in the batch.

        Args:
            support_set: [batch_size, n_way, n_shot, 3, ImageSize, ImageSize]
            query_set: [batch_size, n_way, n_query, 3, ImageSize, ImageSize]
            support_labels: optional labels indexed by task; ``support_labels[b]``
                labels the n_way * n_shot flattened support samples of task b.
                When None, labels are 0..n_way-1 repeated n_shot times each
                (class-major order).
            mode: not used by this method.

        Returns:
            query_logits: [batch_size, n_way * n_query, n_way] gate-weighted expert
                scores (logits for 'NN', probabilities for the other classifiers).
            klg, klb: gate-weighted KL regularisers, summed over the batch.

        Raises:
            ValueError: if ``self.classifier`` is not a supported name.
        """
        batch_size = support_set.size(0)
        support_set = support_set.to(self.device)
        query_set = query_set.to(self.device)

        # Extract features for support and query sets
        support_features = self._extract_features_from_set(support_set)  # [batch, n_way, n_shot, feature_dim]
        query_features = self._extract_features_from_set(query_set)      # [batch, n_way, n_query, feature_dim]

        # Gate on the per-class mean support features: [batch, n_way, feature_dim]
        expert_weights = self.gate(support_features.mean(dim=2))

        # Each expert returns n_gaus stacked copies of the flattened support set,
        # so the labels are tiled n_gaus times to match support_output.reshape(-1, F).
        default_labels = (
            torch.arange(self.n_way, device=self.device)
            .repeat_interleave(self.n_shot)
            .repeat(self.n_gaus)
        )

        all_query_logits = []
        all_query_klg = []
        all_query_klb = []
        for b in range(batch_size):
            flat_support = support_features[b].view(-1, self.feature_dim)  # [n_way*n_shot, feature_dim]
            flat_query = query_features[b].view(-1, self.feature_dim)      # [n_way*n_query, feature_dim]

            if support_labels is None:
                labels = default_labels
            else:
                labels = torch.as_tensor(support_labels[b], device=self.device).long().repeat(self.n_gaus)

            expert_outputs = []
            expert_klg = []
            expert_klb = []
            for expert in self.rim_experts:
                # Fused support features: [n_gaus, n_way*n_shot, feature_dim]
                support_output, _, _ = expert(flat_support, mode="support", n_gaus=self.n_gaus)
                # Fused query features from a single sample: [n_way*n_query, feature_dim]
                query_output, kl_g, kl_b = expert(flat_query, mode="query", n_gaus=1)
                query_output = query_output.squeeze(0)

                expert_outputs.append(self._classify(support_output, labels, query_output))
                expert_klg.append(kl_g)
                expert_klb.append(kl_b)

            # Combine expert outputs using the gating weights of task b
            weights = expert_weights[b]
            stacked_outputs = torch.stack(expert_outputs)  # [experts, n_way*n_query, n_way]
            all_query_logits.append(torch.sum(stacked_outputs * weights.view(-1, 1, 1), dim=0))
            all_query_klg.append(torch.sum(torch.stack(expert_klg) * weights, dim=0))
            all_query_klb.append(torch.sum(torch.stack(expert_klb) * weights, dim=0))

        query_logits = torch.stack(all_query_logits)  # [batch_size, n_way*n_query, n_way]
        klg = torch.sum(torch.stack(all_query_klg), dim=0)
        klb = torch.sum(torch.stack(all_query_klb), dim=0)
        return query_logits, klg, klb

    def _extract_features_from_set(self, x):
        """Embed a support or query set.

        x: [batch, n_way, n_shot/n_query, 3, ImageSize, ImageSize]
        Returns: [batch, n_way, n_shot/n_query, feature_dim]
        """
        batch_size, n_way, n_samples = x.shape[:3]
        x = x.reshape(-1, 3, self.ImageSize, self.ImageSize)  # Flatten first three dimensions

        with torch.no_grad():
            features = self.feature_extractor(x)

        features = features.squeeze(-1).squeeze(-1)  # Remove spatial dimensions
        return features.view(batch_size, n_way, n_samples, -1)  # Restore original shape





class RIM(nn.Module):
    """Relation inference module: fuses input node features with a bank of base nodes.

    Steps:
    1. An edge embedding is built from the concatenated features of every
       (input node, base node) pair.
    2. A prior and a posterior encoder turn the edge embeddings into edge-weight
       parameters.
    3. ``n_gaus`` weighted graphs are sampled with the reparameterisation trick,
       and each graph row is min-max rescaled to about [-1, 1].
    4. The graphs aggregate the base nodes, and the result is blended with the
       input via ``omega``.

    KL regularisers between posterior and prior are returned alongside.
    """

    def __init__(self, feature_dim, base_num, edge_dim, base_nodes, n_gaus, omega = 0.5, last_dense = 128, out_dim = 64, g_dim = 512, dropout = 0.1):
        """
        Args:
            feature_dim: input feature dimension.
            base_num: number of base classes, i.e. rows of ``base_nodes``.
            edge_dim: width of the edge embeddings.
            base_nodes: [base_num, feature_dim] tensor of base node features.
            n_gaus: number of Gaussian graph samples (``forward`` overwrites it per call).
            omega: weight of the aggregated base nodes in the fused feature; the
                input keeps weight ``1 - omega``.
            last_dense, out_dim, g_dim: sizes of the ``out_layer``/``out_layer_fs``
                heads, which ``forward`` does not use.
            dropout: dropout rate of the edge/encoder MLPs.

        Raises:
            ValueError: if ``base_nodes`` does not have exactly ``base_num`` rows.
                ``forward`` pairs every input with exactly ``base_num`` base nodes.
        """
        super(RIM, self).__init__()

        if base_nodes.shape[0] != base_num:
            raise ValueError(
                f"base_nodes has {base_nodes.shape[0]} rows, expected base_num={base_num}"
            )

        self.omega = omega
        self.feature_dim = feature_dim
        self.base_num = base_num
        self.dropout_rate = dropout

        self.edge_dim = edge_dim
        # Non-persistent buffer: it follows model.to()/.cuda() instead of being forced
        # onto CUDA (which fails on CPU-only hosts), and, like a plain attribute, it is
        # not written to the state_dict.
        self.register_buffer("base_nodes", base_nodes, persistent=False)
        self.n_gaus = n_gaus
        self.g_dim = g_dim
        self.out_dim = out_dim
        self.last_dense = last_dense

        # NOTE: x_linear, post_emb_to_graph, out_layer, out_layer_fs, graph_norm and
        # base_update are not used by forward; they are kept so parameter keys and
        # initialisation order stay unchanged.
        self.x_linear = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.Dropout(p=0.1)
        )

        # prior
        self.prior_enc = encode_mean_std_pair(self.edge_dim, self.edge_dim, self.dropout_rate)
        self.prior_mij = nn.Linear(self.edge_dim, 1)

        # posterior
        self.post_enc = encode_mean_std_pair(self.edge_dim, self.edge_dim, self.dropout_rate)
        self.post_mean_approx_g = nn.Linear(self.edge_dim, 1)
        self.post_std_approx_g = nn.Sequential(nn.Linear(self.edge_dim, 1), nn.Softplus())

        # from edge embeddings to graph
        self.post_emb_to_graph = nn.Sequential(
            nn.Linear(self.edge_dim, self.edge_dim), nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(self.edge_dim, 1), nn.ReLU()
        )

        # edge embedding generation from [base_node ; input_node] pairs
        self.gen_edge_emb = nn.Sequential(
            nn.Linear(self.feature_dim * 2, self.edge_dim), nn.ReLU(),
            nn.Dropout(self.dropout_rate),
            nn.Linear(self.edge_dim, self.edge_dim)
        )
        self.relation_encoder = nn.Sequential(
            nn.Linear(edge_dim, edge_dim * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(edge_dim * 4, edge_dim),
            nn.LayerNorm(edge_dim)
        )

        # classification layer of training phase
        self.out_layer = nn.Sequential(
            nn.Linear(self.g_dim, self.last_dense),
            nn.Dropout(p=0.2),
            nn.Linear(self.last_dense, self.out_dim)
        )

        # classification layer for few-shot tasks
        self.out_layer_fs = nn.Sequential(
            nn.Linear(self.g_dim, self.last_dense),
            nn.Dropout(p=0.2),
            nn.Linear(self.last_dense, 5)
        )

        self.graph_norm = nn.Sequential(
            nn.Linear(1, 1),
            nn.Tanh()
        )

        self.base_update = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

    def forward(self, x, mode = "Train", n_gaus = 1000):
        """Fuse ``x`` with the base nodes through ``n_gaus`` sampled relation graphs.

        Args:
            x: [n_nodes, feature_dim] input node features.
            mode: forwarded to ``sample_repara``, which ignores it.
            n_gaus: number of graphs to sample; also stored on ``self.n_gaus``.

        Returns:
            fusion_features: [n_gaus, n_nodes, feature_dim].
            kl_g: scalar Gaussian KL regulariser (posterior vs prior).
            kl_b: scalar binomial-style KL upper bound (posterior vs prior).
        """
        self.n_gaus = n_gaus

        x_node_emb = x.clone()
        x_node0 = x_node_emb
        base_nodes = self.base_nodes.to(x.device)
        n_nodes = x_node_emb.shape[0]

        # node2edge: node_pairs[j, i] = [base_nodes[i] ; x[j]], built without a Python
        # double loop -> [n_nodes, base_num, 2 * feature_dim]
        node_pairs = torch.cat(
            [
                base_nodes.unsqueeze(0).expand(n_nodes, -1, -1),
                x_node_emb.unsqueeze(1).expand(-1, self.base_num, -1),
            ],
            dim=-1,
        )
        edge_emb = self.gen_edge_emb(node_pairs)  # [n_nodes, base_num, edge_dim]
        edge_emb = self.relation_encoder(edge_emb)

        input4prior = edge_emb.clone()
        input4post = edge_emb.clone()

        # prior graph
        prior_mean, prior_std, prior_b = self.prior_enc(input4prior)
        prior_mij = 0.4 * sigmoid(self.prior_mij(prior_b))
        prior_mij = prior_mij.squeeze(-1)  # [n_nodes, base_num]

        # posterior graph
        post_mean_g, post_std_g, post_b = self.post_enc(input4post)
        post_mean_approx_g = self.post_mean_approx_g(post_b).squeeze(-1)
        post_std_approx_g = self.post_std_approx_g(post_b).squeeze(-1)

        # estimate post mij for the binomial distribution
        nij = softplus(post_mean_approx_g) + 0.01
        nij_ = 2.0 * nij * post_std_approx_g.pow(2)
        post_mij = 0.5 * (1.0 + nij_ - torch.sqrt(nij_.pow(2) + 1))

        post_mean_g = post_mean_g.squeeze(-1)
        post_std_g = post_std_g.squeeze(-1)

        alpha_bars, alpha_tilde = self.sample_repara(post_mean_g, post_std_g, post_mij, self.n_gaus, mode)

        # Rescale every graph row (over the base nodes) to about [-1, 1].
        alpha_max = alpha_bars.max(dim=-1, keepdim=True)[0]
        alpha_min = alpha_bars.min(dim=-1, keepdim=True)[0]
        normalized_graphs = (2 * alpha_bars - alpha_max - alpha_min) / (alpha_max - alpha_min + 1e-6)
        # normalized_graphs: [n_gaus, n_nodes, base_num]

        # regularization
        a1 = alpha_tilde * post_mean_g
        a2 = torch.sqrt(alpha_tilde) * post_std_g
        a3 = alpha_tilde * prior_mean.squeeze(-1)
        a4 = torch.sqrt(alpha_tilde) * prior_std.squeeze(-1)

        kl_g = self.kld_loss_gauss(a1, a2, a3, a4)
        kl_b = self.kld_loss_binomial_upper_bound(post_mij, prior_mij)

        # fusion feature generation for all samples at once
        weighted_base = torch.matmul(normalized_graphs, base_nodes)  # [n_gaus, n_nodes, feature_dim]
        fusion_features = self.omega * weighted_base + (1 - self.omega) * x_node0

        return fusion_features, kl_g, kl_b

    def sample_repara(self, mean, std, mij, n_gaus, mode):
        """Sample ``n_gaus`` edge-weight graphs with the reparameterisation trick.

        The steps are:
        - ``alpha_tilde = softplus(N(mij, sqrt(mij * (1 - mij))))`` is drawn once
          and shared by all samples.
        - Each sample draws ``s ~ N(mean, std)``.
        - Each sample returns ``s * alpha_tilde``.

        ``mode`` is accepted but not used.

        Returns:
            alpha_bars: [n_gaus, *std.shape]
            alpha_tilde: same shape as ``std``.
        """
        mean_alpha = mij
        std_alpha = torch.sqrt(mij * (1.0 - mij))

        eps = torch.randn_like(std)
        alpha_tilde = softplus(eps * std_alpha + mean_alpha)

        eps = torch.randn((n_gaus,) + tuple(std.shape), device=std.device, dtype=std.dtype)
        alpha_bars = (eps * std + mean) * alpha_tilde
        return alpha_bars, alpha_tilde

    def norm(self, x):
        """Min-max scale ``x`` to [0, 1] over all elements (divides by zero if x is constant)."""
        return (x - torch.min(x)) / (torch.max(x) - torch.min(x))

    # KL loss of the Gaussian distribution estimate
    def kld_loss_gauss(self, mean_post, std_post, mean_prior, std_prior, eps=1e-6):
        """Sum of |elementwise KL(N(mean_post, std_post) || N(mean_prior, std_prior))|."""
        kld_element = (2 * torch.log(std_prior + eps) - 2 * torch.log(std_post + eps) +
                       ((std_post).pow(2) + (mean_post - mean_prior).pow(2)) /
                       (std_prior + eps).pow(2) - 1)
        return 0.5 * torch.sum(torch.abs(kld_element))

    # ELBO term
    def kld_loss_binomial_upper_bound(self, mij_post, mij_prior, eps=1e-6):
        """Sum of |m_prior - m_post + m_post * (log m_post - log m_prior)| with ``eps`` inside the logs."""
        kld_element = mij_prior - mij_post + \
                       mij_post * (torch.log(mij_post+eps) - torch.log(mij_prior+eps))
        return torch.sum(torch.abs(kld_element))



class SimpleNN(nn.Module):
    """Two-layer MLP classifier: Linear(input_dim, 128) -> ReLU -> Linear(128, num_classes).

    ``forward`` returns unnormalised class scores; no softmax is applied.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        hidden_width = 128
        layers = [
            nn.Linear(input_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.classifier(x)
        return logits